{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.train\n",
    "# !wget https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse(file):\n",
    "    with open(file) as fopen:\n",
    "        texts = fopen.read().split('\\n')\n",
    "    left, right = [], []\n",
    "    for text in texts:\n",
    "        if '-DOCSTART-' in text or not len(text):\n",
    "            continue\n",
    "        splitted = text.split()\n",
    "        left.append(splitted[0])\n",
    "        right.append(splitted[1])\n",
    "    return left, right"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "left_train, right_train = parse('eng.train')\n",
    "left_test, right_test = parse('eng.testa')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq_len = 50\n",
    "def iter_seq(x):\n",
    "    return np.array([x[i: i+seq_len] for i in range(0, len(x)-seq_len, 1)])\n",
    "\n",
    "def to_train_seq(*args):\n",
    "    return [iter_seq(x) for x in args]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "left_train, right_train = to_train_seq(left_train, right_train)\n",
    "left_test, right_test = to_train_seq(left_test, right_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'PAD': 0,\n",
       " '\"': 1,\n",
       " '$': 2,\n",
       " \"''\": 3,\n",
       " '(': 4,\n",
       " ')': 5,\n",
       " ',': 6,\n",
       " '.': 7,\n",
       " ':': 8,\n",
       " 'CC': 9,\n",
       " 'CD': 10,\n",
       " 'DT': 11,\n",
       " 'EX': 12,\n",
       " 'FW': 13,\n",
       " 'IN': 14,\n",
       " 'JJ': 15,\n",
       " 'JJR': 16,\n",
       " 'JJS': 17,\n",
       " 'LS': 18,\n",
       " 'MD': 19,\n",
       " 'NN': 20,\n",
       " 'NNP': 21,\n",
       " 'NNPS': 22,\n",
       " 'NNS': 23,\n",
       " 'NN|SYM': 24,\n",
       " 'PDT': 25,\n",
       " 'POS': 26,\n",
       " 'PRP': 27,\n",
       " 'PRP$': 28,\n",
       " 'RB': 29,\n",
       " 'RBR': 30,\n",
       " 'RBS': 31,\n",
       " 'RP': 32,\n",
       " 'SYM': 33,\n",
       " 'TO': 34,\n",
       " 'UH': 35,\n",
       " 'VB': 36,\n",
       " 'VBD': 37,\n",
       " 'VBG': 38,\n",
       " 'VBN': 39,\n",
       " 'VBP': 40,\n",
       " 'VBZ': 41,\n",
       " 'WDT': 42,\n",
       " 'WP': 43,\n",
       " 'WP$': 44,\n",
       " 'WRB': 45}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tag2idx = {'PAD': 0}\n",
    "for no, u in enumerate(np.unique(right_train)):\n",
    "    tag2idx[u] = no + 1\n",
    "tag2idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([['EU', 'rejects', 'German', ..., 'the', 'European', 'Union'],\n",
       "       ['rejects', 'German', 'call', ..., 'European', 'Union', \"'s\"],\n",
       "       ['German', 'call', 'to', ..., 'Union', \"'s\", 'veterinary'],\n",
       "       ...,\n",
       "       ['Peter', 'Hedblom', '(', ..., 'Division', 'three', 'Swansea'],\n",
       "       ['Hedblom', '(', 'Sweden', ..., 'three', 'Swansea', '1'],\n",
       "       ['(', 'Sweden', ')', ..., 'Swansea', '1', 'Lincoln']], dtype='<U61')"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "left_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "BERT_VOCAB = 'uncased_L-12_H-768_A-12/vocab.txt'\n",
    "BERT_INIT_CHKPNT = 'uncased_L-12_H-768_A-12/bert_model.ckpt'\n",
    "BERT_CONFIG = 'uncased_L-12_H-768_A-12/bert_config.json'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import bert\n",
    "from bert import run_classifier\n",
    "from bert import optimization\n",
    "from bert import tokenization\n",
    "from bert import modeling\n",
    "\n",
    "tokenization.validate_case_matches_checkpoint(True,BERT_INIT_CHKPNT)\n",
    "tokenizer = tokenization.FullTokenizer(\n",
    "      vocab_file=BERT_VOCAB, do_lower_case=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def XY(left_train, right_train):\n",
    "    X, Y = [], []\n",
    "    for i in tqdm(range(len(left_train))):\n",
    "        left = left_train[i]\n",
    "        right = right_train[i]\n",
    "        bert_tokens = ['[CLS]']\n",
    "        y = ['PAD']\n",
    "        for no, orig_token in enumerate(left):\n",
    "            y.append(right[no])\n",
    "            t = tokenizer.tokenize(orig_token)\n",
    "            bert_tokens.extend(t)\n",
    "            y.extend(['PAD'] * (len(t) - 1))\n",
    "        bert_tokens.append(\"[SEP]\")\n",
    "        y.append('PAD')\n",
    "        X.append(tokenizer.convert_tokens_to_ids(bert_tokens))\n",
    "        Y.append([tag2idx[i] for i in y])\n",
    "    return X, Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 203571/203571 [03:32<00:00, 959.85it/s] \n"
     ]
    }
   ],
   "source": [
    "train_X, train_Y = XY(left_train, right_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "train_X = keras.preprocessing.sequence.pad_sequences(train_X, padding='post')\n",
    "train_Y = keras.preprocessing.sequence.pad_sequences(train_Y, padding='post')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 51312/51312 [00:52<00:00, 970.83it/s] \n"
     ]
    }
   ],
   "source": [
    "test_X, test_Y = XY(left_test, right_test)\n",
    "test_X = keras.preprocessing.sequence.pad_sequences(test_X, padding='post')\n",
    "test_Y = keras.preprocessing.sequence.pad_sequences(test_Y, padding='post')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "epoch = 3\n",
    "batch_size = 16\n",
    "warmup_proportion = 0.1\n",
    "num_train_steps = int(len(train_X) / batch_size * epoch)\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)\n",
    "bert_config = modeling.BertConfig.from_json_file(BERT_CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    def __init__(\n",
    "        self,\n",
    "        dimension_output,\n",
    "        learning_rate = 2e-5,\n",
    "    ):\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        self.maxlen = tf.shape(self.X)[1]\n",
    "        self.lengths = tf.count_nonzero(self.X, 1)\n",
    "        \n",
    "        model = modeling.BertModel(\n",
    "            config=bert_config,\n",
    "            is_training=True,\n",
    "            input_ids=self.X,\n",
    "            use_one_hot_embeddings=False)\n",
    "        output_layer = model.get_sequence_output()\n",
    "        logits = tf.layers.dense(output_layer, dimension_output)\n",
    "        y_t = self.Y\n",
    "        log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(\n",
    "            logits, y_t, self.lengths\n",
    "        )\n",
    "        self.cost = tf.reduce_mean(-log_likelihood)\n",
    "        self.optimizer = tf.train.AdamOptimizer(\n",
    "            learning_rate = learning_rate\n",
    "        ).minimize(self.cost)\n",
    "        mask = tf.sequence_mask(self.lengths, maxlen = self.maxlen)\n",
    "        self.tags_seq, tags_score = tf.contrib.crf.crf_decode(\n",
    "            logits, transition_params, self.lengths\n",
    "        )\n",
    "        self.tags_seq = tf.identity(self.tags_seq, name = 'logits')\n",
    "\n",
    "        y_t = tf.cast(y_t, tf.int32)\n",
    "        self.prediction = tf.boolean_mask(self.tags_seq, mask)\n",
    "        mask_label = tf.boolean_mask(y_t, mask)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        correct_index = tf.cast(correct_pred, tf.float32)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n",
      "\n",
      "WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From /home/jupyter/.local/lib/python3.6/site-packages/bert/modeling.py:358: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n",
      "WARNING:tensorflow:From /home/jupyter/.local/lib/python3.6/site-packages/bert/modeling.py:671: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dense instead.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/crf/python/ops/crf.py:213: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `keras.layers.RNN(cell)`, which is equivalent to this API\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/rnn.py:626: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gradients_impl.py:110: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
      "  \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use standard file APIs to check for files with this prefix.\n",
      "INFO:tensorflow:Restoring parameters from uncased_L-12_H-768_A-12/bert_model.ckpt\n"
     ]
    }
   ],
   "source": [
    "dimension_output = len(tag2idx)\n",
    "learning_rate = 2e-5\n",
    "\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(\n",
    "    dimension_output,\n",
    "    learning_rate\n",
    ")\n",
    "\n",
    "sess.run(tf.global_variables_initializer())\n",
    "var_lists = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope = 'bert')\n",
    "saver = tf.train.Saver(var_list = var_lists)\n",
    "saver.restore(sess, BERT_INIT_CHKPNT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop:  41%|████      | 5169/12724 [28:24<40:33,  3.10it/s, accuracy=0.966, cost=8.39]  IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "train minibatch loop:  65%|██████▌   | 8302/12724 [45:35<23:57,  3.08it/s, accuracy=0.91, cost=19.6]  IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "train minibatch loop:  92%|█████████▏| 11698/12724 [1:04:08<05:45,  2.97it/s, accuracy=0.921, cost=21.6] IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "test minibatch loop:  64%|██████▍   | 2065/3207 [04:18<02:21,  8.06it/s, accuracy=0.927, cost=13.2] IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "train minibatch loop:  17%|█▋        | 2225/12724 [12:11<55:21,  3.16it/s, accuracy=0.968, cost=5.82]   IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "train minibatch loop:  43%|████▎     | 5506/12724 [30:01<40:11,  2.99it/s, accuracy=0.963, cost=5.82] IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "train minibatch loop:  68%|██████▊   | 8607/12724 [46:54<21:45,  3.15it/s, accuracy=0.994, cost=2.3]  IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "train minibatch loop:  95%|█████████▍| 12075/12724 [1:05:43<03:27,  3.13it/s, accuracy=0.977, cost=4.14] IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "test minibatch loop:  82%|████████▏ | 2632/3207 [05:27<01:16,  7.54it/s, accuracy=0.934, cost=18.4] IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "train minibatch loop:  23%|██▎       | 2915/12724 [15:57<55:07,  2.97it/s, accuracy=0.999, cost=0.531]  IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "train minibatch loop:  51%|█████     | 6512/12724 [35:32<33:40,  3.07it/s, accuracy=0.998, cost=0.888]IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "train minibatch loop: 100%|██████████| 12724/12724 [1:09:14<00:00,  3.71it/s, accuracy=0.962, cost=4.34]  \n",
      "test minibatch loop: 100%|██████████| 3207/3207 [06:36<00:00,  8.28it/s, accuracy=0.929, cost=24.9] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 4551.9211502075195\n",
      "epoch: 2, training loss: 2.525351, training acc: 0.987957, valid loss: 16.345582, valid acc: 0.952310\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "for e in range(3):\n",
    "    lasttime = time.time()\n",
    "    train_acc, train_loss, test_acc, test_loss = 0, 0, 0, 0\n",
    "    pbar = tqdm(\n",
    "        range(0, len(train_X), batch_size), desc = 'train minibatch loop'\n",
    "    )\n",
    "    for i in pbar:\n",
    "        batch_x = train_X[i : min(i + batch_size, train_X.shape[0])]\n",
    "        batch_y = train_Y[i : min(i + batch_size, train_X.shape[0])]\n",
    "        acc, cost, _ = sess.run(\n",
    "            [model.accuracy, model.cost, model.optimizer],\n",
    "            feed_dict = {\n",
    "                model.X: batch_x,\n",
    "                model.Y: batch_y\n",
    "            },\n",
    "        )\n",
    "        assert not np.isnan(cost)\n",
    "        train_loss += cost\n",
    "        train_acc += acc\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "    \n",
    "    pbar = tqdm(\n",
    "        range(0, len(test_X), batch_size), desc = 'test minibatch loop'\n",
    "    )\n",
    "    for i in pbar:\n",
    "        batch_x = test_X[i : min(i + batch_size, test_X.shape[0])]\n",
    "        batch_y = test_Y[i : min(i + batch_size, test_X.shape[0])]\n",
    "        acc, cost = sess.run(\n",
    "            [model.accuracy, model.cost],\n",
    "            feed_dict = {\n",
    "                model.X: batch_x,\n",
    "                model.Y: batch_y\n",
    "            },\n",
    "        )\n",
    "        assert not np.isnan(cost)\n",
    "        test_loss += cost\n",
    "        test_acc += acc\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "    \n",
    "    train_loss /= len(train_X) / batch_size\n",
    "    train_acc /= len(train_X) / batch_size\n",
    "    test_loss /= len(test_X) / batch_size\n",
    "    test_acc /= len(test_X) / batch_size\n",
    "\n",
    "    print('time taken:', time.time() - lasttime)\n",
    "    print(\n",
    "        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'\n",
    "        % (e, train_loss, train_acc, test_loss, test_acc)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx2tag = {i: w for w, i in tag2idx.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pred2label(pred):\n",
    "    out = []\n",
    "    for pred_i in pred:\n",
    "        out_i = []\n",
    "        for p in pred_i:\n",
    "            out_i.append(idx2tag[p])\n",
    "        out.append(out_i)\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "validation minibatch loop: 100%|██████████| 3207/3207 [06:20<00:00,  8.55it/s]\n"
     ]
    }
   ],
   "source": [
    "real_Y, predict_Y = [], []\n",
    "\n",
    "pbar = tqdm(\n",
    "    range(0, len(test_X), batch_size), desc = 'validation minibatch loop'\n",
    ")\n",
    "for i in pbar:\n",
    "    batch_x = test_X[i : min(i + batch_size, test_X.shape[0])]\n",
    "    batch_y = test_Y[i : min(i + batch_size, test_X.shape[0])]\n",
    "    predicted = pred2label(sess.run(model.tags_seq,\n",
    "            feed_dict = {\n",
    "                model.X: batch_x,\n",
    "                model.Y: batch_y,\n",
    "            },\n",
    "    ))\n",
    "    real = pred2label(batch_y)\n",
    "    predict_Y.extend(predicted)\n",
    "    real_Y.extend(real)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "             precision    recall  f1-score   support\n",
      "\n",
      "          \"       1.00      1.00      1.00     32000\n",
      "          $       0.99      0.99      0.99      5050\n",
      "         ''       0.99      0.67      0.80       550\n",
      "          (       1.00      1.00      1.00     33950\n",
      "          )       1.00      1.00      1.00     34000\n",
      "          ,       1.00      1.00      1.00     97450\n",
      "          .       1.00      1.00      1.00     93842\n",
      "          :       0.94      0.97      0.96     31105\n",
      "         CC       1.00      1.00      1.00     46551\n",
      "         CD       0.95      0.98      0.97    214444\n",
      "         DT       1.00      0.99      0.99    175916\n",
      "         EX       0.90      0.98      0.94      1998\n",
      "         FW       0.42      0.30      0.35      1450\n",
      "         IN       0.99      0.98      0.98    248430\n",
      "         JJ       0.86      0.86      0.86    152092\n",
      "        JJR       0.80      0.82      0.81      5250\n",
      "        JJS       0.95      0.89      0.92      3900\n",
      "         LS       0.00      0.00      0.00        50\n",
      "         MD       0.99      0.99      0.99     15000\n",
      "         NN       0.94      0.86      0.90    300941\n",
      "        NNP       0.89      0.94      0.92    427200\n",
      "       NNPS       0.68      0.52      0.59      8300\n",
      "        NNS       0.96      0.94      0.95    125063\n",
      "     NN|SYM       0.46      0.12      0.19        50\n",
      "        PAD       1.00      1.00      1.00   4104960\n",
      "        PDT       0.30      0.89      0.45       350\n",
      "        POS       0.99      0.99      0.99     21150\n",
      "        PRP       1.00      0.99      0.99     43100\n",
      "       PRP$       0.98      1.00      0.99     21086\n",
      "         RB       0.88      0.91      0.89     49446\n",
      "        RBR       0.68      0.54      0.60      2650\n",
      "        RBS       0.72      0.82      0.77       900\n",
      "         RP       0.74      0.77      0.75      7500\n",
      "        SYM       0.96      0.92      0.94      4300\n",
      "         TO       1.00      1.00      1.00     45288\n",
      "         UH       0.16      0.20      0.18       250\n",
      "         VB       0.93      0.89      0.91     55934\n",
      "        VBD       0.93      0.96      0.95    111414\n",
      "        VBG       0.87      0.91      0.89     35000\n",
      "        VBN       0.91      0.82      0.86     49650\n",
      "        VBP       0.85      0.87      0.86     18250\n",
      "        VBZ       0.91      0.92      0.91     25550\n",
      "        WDT       0.92      0.93      0.92      7750\n",
      "         WP       1.00      0.98      0.99      6350\n",
      "        WP$       1.00      1.00      1.00       450\n",
      "        WRB       1.00      1.00      1.00      4650\n",
      "\n",
      "avg / total       0.98      0.98      0.98   6670560\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "print(classification_report(np.array(real_Y).ravel(), np.array(predict_Y).ravel()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
